#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 11:19:45 2019

@author: Macrobull

Not all ops in this file are supported by both PyTorch and ONNX.
This only demonstrates the conversion/validation workflow from PyTorch to ONNX to Paddle fluid.
"""

from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

from onnx2fluid.torch_export_helper import export_onnx_with_validation

# Each exported example is named `prefix + str(idx)`; idx is bumped before
# every export, so the first example gets index 1.
prefix, idx = 'sample_', 0

######## example: RNN cell ########


class Model(nn.Module):
    """Single time step of a GRUCell (6 -> 5) feeding an LSTMCell (5 -> 4).

    forward(x, h1, h2, c2) returns the LSTM cell's (hidden, cell) pair.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Creation order matters for seeded parameter init: GRU first, then LSTM.
        self.gru = nn.GRUCell(6, 5)
        self.lstm = nn.LSTMCell(5, 4)

    def forward(self, x, h1, h2, c2):
        gru_hidden = self.gru(x, h1)
        hidden, cell = self.lstm(gru_hidden, (h2, c2))
        return hidden, cell


model = Model()
# Put the module in inference mode (nn.Module.eval) before exporting.
model.eval()
# Example inputs: x of shape (7, 6), and all-zero initial states for the
# GRU cell (7, 5) and the LSTM cell hidden/cell pair (7, 4).
xb = torch.rand((7, 6))
h1 = torch.zeros((7, 5))
h2 = torch.zeros((7, 4))
c2 = torch.zeros((7, 4))
# Eager forward pass as a smoke test; `yp` is not used afterwards here.
yp = model(xb, h1, h2, c2)
# Give this example the next index; it is part of the export name below.
idx += 1
print('index: ', idx)
# Export with onnx2fluid's helper. The name lists follow the order of the
# input tensors and of forward()'s return values. The third argument
# ('sample_<idx>') is presumably the output name/prefix — confirm in
# onnx2fluid.torch_export_helper.
export_onnx_with_validation(model, [xb, h1, h2, c2],
                            prefix + str(idx), ['x', 'h1', 'h2', 'c2'],
                            ['h', 'c'],
                            verbose=True,
                            training=False)

######## example: RNN ########


class Model(nn.Module):
    """A 3-layer GRU (6 -> 5) feeding a 2-layer LSTM (5 -> 4).

    Only the LSTM output sequence is returned; final hidden/cell states are
    discarded.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Creation order matters for seeded parameter init: GRU first, then LSTM.
        self.gru = nn.GRU(6, 5, 3)
        self.lstm = nn.LSTM(5, 4, 2)

    def forward(self, x, h1, h2, c2):
        gru_seq, _ = self.gru(x, h1)
        lstm_seq, _ = self.lstm(gru_seq, (h2, c2))
        return lstm_seq


model = Model()
# Put the module in inference mode (nn.Module.eval) before exporting.
model.eval()
# Example inputs. With nn.GRU/nn.LSTM's default (batch_first=False) layout,
# x is (seq_len=8, batch=1, features=6). The initial states are
# (num_layers, batch, hidden): (3, 1, 5) for the GRU and (2, 1, 4) for the LSTM.
xb = torch.rand((8, 1, 6))
h1 = torch.zeros((3, 1, 5))
h2 = torch.zeros((2, 1, 4))
c2 = torch.zeros((2, 1, 4))
# Eager forward pass as a smoke test; `yp` is not used afterwards here.
yp = model(xb, h1, h2, c2)
# Give this example the next index; it is part of the export name below.
idx += 1
print('index: ', idx)
# Export with onnx2fluid's helper. The input names follow the tensor order,
# and the single output is named 'y'.
export_onnx_with_validation(model, [xb, h1, h2, c2],
                            prefix + str(idx), ['x', 'h1', 'h2', 'c2'], ['y'],
                            verbose=True,
                            training=False)

######## example: random ########
"""
    symbolic registration:

    def rand(g, *shapes):
        shapes_list = list(shapes)
        shape = _maybe_get_const(shapes_list[0], "is")
        return g.op('RandomUniform', shape_i=shape)
"""


class Model(nn.Module):
    """Add uniform and normal random noise of shape (2, 3) to the input.

    The noise is drawn fresh on every call, so outputs are not repeatable
    unless the RNG is seeded.
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        # Draw uniform noise before normal noise, matching RNG consumption order.
        # (Alternatives kept for reference: torch.rand_like(x) / torch.randn_like(x).)
        noise = torch.rand((2, 3)) + torch.randn((2, 3))
        return noise + x


model = Model()
# Put the module in inference mode (nn.Module.eval) before exporting.
model.eval()
# Example input; shape (2, 3) matches the fixed-size noise drawn in forward().
xb = torch.rand((2, 3))
# Eager forward pass as a smoke test; `yp` is not used afterwards here.
# NOTE(review): forward() draws new random values on each call, so eager and
# exported outputs will differ unless the helper reseeds — confirm how
# validation treats this example.
yp = model(xb)
# Give this example the next index; it is part of the export name below.
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['x'], ['y'],
                            verbose=True,
                            training=False)

######## example: fc ########


class Model(nn.Module):
    """A single fully connected layer mapping 3 features to 8."""

    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(3, 8)

    def forward(self, x):
        return self.fc(x)


model = Model()
# Put the module in inference mode (nn.Module.eval) before exporting.
model.eval()
# Example input: 2 rows of 3 features, matching nn.Linear(3, 8).
xb = torch.rand((2, 3))
# Eager forward pass as a smoke test; `yp` is not used afterwards here.
yp = model(xb)
# Give this example the next index; it is part of the export name below.
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['x'], ['y'],
                            verbose=True,
                            training=False)

######## example: compare ########


class Model(nn.Module):
    """Exercise elementwise Max, Equal, Less and Greater after a clamp.

    Returns three boolean masks (x1 is the max, x0 < x1, x0 > x1), where x0
    is first clamped to [-1, 1].
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x0, x1):
        clamped = x0.clamp(-1, 1)
        x1_is_max = torch.max(clamped, x1) == x1
        less_mask = clamped < x1
        greater_mask = clamped > x1
        return x1_is_max, less_mask, greater_mask


model = Model()
# Put the module in inference mode (nn.Module.eval) before exporting.
model.eval()
# Two same-shaped random inputs to compare elementwise.
xb0 = torch.rand((2, 3))
xb1 = torch.rand((2, 3))
# Eager forward pass as a smoke test; the three masks are not used afterwards here.
ya, yb, yc = model(xb0, xb1)
# Give this example the next index; it is part of the export name below.
idx += 1
print('index: ', idx)
# The three output names follow the order of forward()'s returned masks.
export_onnx_with_validation(model, [xb0, xb1],
                            prefix + str(idx), ['x0', 'x1'], ['ya', 'yb', 'yc'],
                            verbose=True,
                            training=False)

######## example: affine_grid ########
"""
    symbolic registration:

    @parse_args('v', 'is')
    def affine_grid_generator(g, theta, size):
        return g.op('AffineGrid', theta, size_i=size)
"""


class Model(nn.Module):
    """Generate an affine sampling grid for a fixed (2, 2, 8, 8) output size."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, theta):
        # Target size as (N, C, H, W); the grid comes back as (N, H, W, 2).
        target_size = (2, 2, 8, 8)
        return F.affine_grid(theta, target_size)


model = Model()
# Put the module in inference mode (nn.Module.eval) before exporting.
model.eval()
# A batch of 2 random 2x3 affine matrices.
theta = torch.rand((2, 2, 3))
# Eager forward pass as a smoke test; `grid` is not used afterwards here.
grid = model(theta)
# Give this example the next index; it is part of the export name below.
idx += 1
print('index: ', idx)
# Inputs are passed as a tuple here (the other examples use lists).
export_onnx_with_validation(model, (theta, ),
                            prefix + str(idx), ['theta'], ['grid'],
                            verbose=True,
                            training=False)

######## example: conv2d_transpose ########


class Model(nn.Module):
    """Transposed 3x3 convolution (3 -> 8 channels) followed by Dropout2d.

    In eval mode the dropout is a no-op, so only the deconvolution matters.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.ConvTranspose2d(3, 8, 3)
        self.dropout = nn.Dropout2d()

    def forward(self, x):
        return self.dropout(self.conv(x))


model = Model()
# Eval mode turns the Dropout2d layer into a pass-through during export.
model.eval()
# Example input of shape (N=2, C=3, H=4, W=5), matching the 3 input channels.
xb = torch.rand((2, 3, 4, 5))
# Eager forward pass as a smoke test; `yp` is not used afterwards here.
yp = model(xb)
# Give this example the next index; it is part of the export name below.
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['x'], ['y'],
                            verbose=True,
                            training=False)

######## example: conv2d ########


class Model(nn.Module):
    """Conv2d (3 -> 8, 3x3), then BatchNorm2d, then global average pooling."""

    def __init__(self):
        super(Model, self).__init__()
        # Creation order matters for seeded parameter init.
        self.conv = nn.Conv2d(3, 8, 3)
        self.batch_norm = nn.BatchNorm2d(8)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.batch_norm(features)
        return self.pool(normalized)


model = Model()
# Eval mode makes BatchNorm2d use its running statistics during export.
model.eval()
# Example input of shape (N=2, C=3, H=4, W=5), matching the 3 input channels.
xb = torch.rand((2, 3, 4, 5))
# Eager forward pass as a smoke test; `yp` is not used afterwards here.
yp = model(xb)
# Give this example the next index; it is part of the export name below.
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['x'], ['y'],
                            verbose=True,
                            training=False)

######### example: conv1d (disabled; the body below currently only exercises BatchNorm2d) ########
#
#class Model(nn.Module):
#    def __init__(self):
#        super(Model, self).__init__()
#        self.batch_norm = nn.BatchNorm2d(3)
#
#    def forward(self, x):
#        y = x
#        y = self.batch_norm(y)
#        return y
#
#
#model = Model()
#model.eval()
#xb = torch.rand((2, 3, 4, 5))
#yp = model(xb)
#idx += 1
#print('index: ', idx)
#export_onnx_with_validation(
#        model, [xb], prefix + str(idx),
#        ['x'], ['y'],
#        verbose=True, training=False)

######## example: empty ########


class Model(nn.Module):
    """Identity model: forward() hands back its input object unchanged."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        passthrough = x
        return passthrough


model = Model()
# Put the module in inference mode (nn.Module.eval) before exporting.
model.eval()
xb = torch.rand((2, 3))
# Eager forward pass as a smoke test; `yp` is not used afterwards here.
yp = model(xb)
# Give this example the next index; it is part of the export name below.
idx += 1
print('index: ', idx)
# NOTE(review): the input is named 'y' rather than 'x' as in the other
# examples. forward() returns its input unchanged, so the graph input and
# output are the same value; using one name for both is possibly deliberate
# to avoid conflicting renames. Confirm before changing it.
export_onnx_with_validation(model, [xb],
                            prefix + str(idx), ['y'], ['y'],
                            verbose=True,
                            training=False)
